"""run_experiment.py
This script runs an experiment with adversarial attacks and defenses on a given model and dataset. It allows for the configuration of various parameters through command line arguments or a configuration file.
Arguments for argparse:
Arguments to load and save configuration:
    --use_config: If set, the script will use a configuration file to load parameters. (default: False)
    --config: Path to the configuration file. Supported formats are YAML and JSON. (default: '')
    --save_config: If set, the script will save the current configuration to a file. (default: False)
Arguments for the experiment:
    --model_attack: The model used to create adversarial examples. Options are 'ResNet50' or 'Vit'. (default: 'ResNet50')
    --batch_size: Batch size for the experiment. (default: 16)
    --dataset: Dataset to use for the experiment. Options are 'imagenet' or 'imagenet_5000'. (default: 'imagenet_5000')
Arguments for the attack:
    --attack: The attack method to use. Options include 'FGSM', 'iFGSM', 'PGD', 'CW', and 'DeepFool'. (default: 'PGD')
    --norm: Norm to use for the attack. Currently not used. (default: 'l_inf')
    --epsilons: List of epsilon values for the attack in the form X/255. For 'CW' and 'DeepFool', the /255 part is ignored. (default: ['8/255'])
    --attack_steps: Number of steps for the attack. (default: 10)
Arguments for the defense:
    --defense: The defense method to use. Options are 'jpeg', 'HiFiC', 'ELIC', 'CRDR', or 'hyperprior'. (default: None)
    --defense_param: List of parameters to use for the defense (e.g., quality for JPEG) in the form 'quality:25 iterations:5 ...'. Check the defenses in defense/ for more information. (default: '')
    --attack_model_param: List of parameters for the separate defense instance used inside the attack model when --attack_through is set, in the same form as --defense_param. If empty, the attack model is derived from --defense_param. (default: '')
    --non_diff: Deprecated. If set, the defense uses the non-differentiable version. Only available for 'HiFiC'. (default: False)
    --attack_through: If set, the gradient is run through the defense when creating adversarial examples, leading to a white-box attack. (default: False)
Arguments for the output:
    --output: Path to save the results. (default: 'results/')
    --get_baseline: If set, the result includes the baseline accuracy of the model. (default: False)


"""
import argparse
import json
import yaml
from experiment_base import Experiment
import torch
from defense.create_defense import create_defense
import pandas as pd
import numpy as np
import datetime
from util import load_model,load_testset
import os



def save_config(args, path):
    """Save the configuration (all attributes of ``args``) to a YAML or JSON file.

    Args:
        args: ``argparse.Namespace`` whose attributes are written out.
        path: Destination file. The extension selects the format:
            ``.yaml``/``.yml`` for YAML, ``.json`` for JSON. Missing parent
            directories are created.

    Raises:
        ValueError: If the extension is not supported. The check runs before
            the file is opened, so no empty file is left behind.
    """
    if path.endswith(('.yaml', '.yml')):
        use_yaml = True
    elif path.endswith('.json'):
        use_yaml = False
    else:
        raise ValueError("Unsupported file format. Use 'yaml' or 'json'.")

    # Create e.g. the default 'configs/' folder so saving after a long
    # experiment does not fail on a missing directory.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(path, 'w', encoding='utf-8') as file:
        if use_yaml:
            yaml.dump(vars(args), file)
        else:
            json.dump(vars(args), file, indent=6)


def load_config(path):
    """Load a configuration dictionary from a YAML or JSON file.

    Args:
        path: Source file. ``.yaml``/``.yml`` is parsed as YAML, ``.json``
            as JSON.

    Returns:
        The parsed configuration. An empty YAML file yields ``{}``.

    Raises:
        ValueError: If the extension is not supported. This is checked before
            the file is opened.
    """
    if path.endswith(('.yaml', '.yml')):
        with open(path, 'r', encoding='utf-8') as file:
            # safe_load suffices: configs written by save_config contain only
            # argparse values (strings, numbers, booleans, lists, None), and
            # it refuses to construct arbitrary Python objects.
            config = yaml.safe_load(file)
        return {} if config is None else config
    if path.endswith('.json'):
        with open(path, 'r', encoding='utf-8') as file:
            return json.load(file)
    raise ValueError("Unsupported file format. Use 'yaml' or 'json'.")

def assign_args_from_config(args, config):
    """Overwrite attributes of ``args`` with the entries of ``config``.

    Keys are applied in iteration order. The first key that ``args`` does not
    already define raises ``ValueError``; entries applied before it stay set.
    """
    for name, new_value in config.items():
        if not hasattr(args, name):
            raise ValueError(f"Argument {name} not found in args")
        setattr(args, name, new_value)



def combine_models(defense, model):
    """Chain ``defense`` and ``model`` so inputs pass through the defense first."""
    pipeline_stages = (defense, model)
    return torch.nn.Sequential(*pipeline_stages)

def save_results(results, epsilons, path):
    """Write the experiment results to ``path + '.csv'``.

    Args:
        results: A flat sequence (one value per column) or a nested 2-D
            sequence; ``np.atleast_2d`` turns a flat sequence into one row.
        epsilons: Column labels, one per result column.
        path: Output path without the ``.csv`` extension. Missing parent
            directories are created, so a finished experiment is not lost
            because the output folder does not exist yet.
    """
    table = pd.DataFrame(np.atleast_2d(np.array(results)), columns=epsilons)
    csv_path = path + '.csv'
    directory = os.path.dirname(csv_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    table.to_csv(csv_path)

# Normalization statistics handed to the defenses when ``is_normalized`` is set.
_NORMALIZATION_MEAN = (0.485, 0.456, 0.406)
_NORMALIZATION_STD = (0.229, 0.224, 0.225)
# Attacks whose attack model is always the full evaluation defense
# (no single-iteration substitute is built for them).
_FULL_DEFENSE_ATTACKS = ('STBPDA', 'OBPDA', 'FSLBPDA')


def _defended_model(model, defense_arg, params, is_normalized):
    """Create defense ``defense_arg`` with ``params`` and prepend it to ``model``."""
    defense = create_defense(defense_arg, params)
    if is_normalized:
        defense.set_normalization_used(_NORMALIZATION_MEAN, _NORMALIZATION_STD)
    return combine_models(defense, model)


def _iterations_of(params):
    """Return N from the first entry of ``params`` starting with 'iterations' ('iterations:N').

    Returns None when there is no such entry or it carries no ':' value.
    A non-integer value raises ValueError.
    """
    for entry in params or ():
        if entry.startswith('iterations'):
            fields = entry.split(':')
            return int(fields[1]) if len(fields) > 1 else None
    return None


def create_attack_defense(model,defense_arg,defense_param,attack_param,attack_through,is_normalized = False,attack=None):
    """Build the models used to craft and to evaluate adversarial examples.

    Args:
        model: The classifier being attacked/defended.
        defense_arg: Defense name passed to ``create_defense``, or ``None`` for
            no defense.
        defense_param: Defense parameters (``'key:value'`` strings) passed to
            ``create_defense``. Not modified by this function.
        attack_param: Parameters for a separate defense instance used only in
            the attack model. Empty (``''``, ``[]`` or ``None``) means the
            attack model is derived from ``defense_param``.
        attack_through: If true, the attack model contains a defense, so
            gradients flow through it. Otherwise the bare ``model`` is attacked.
        is_normalized: If true, each created defense is given the
            ``_NORMALIZATION_MEAN``/``_NORMALIZATION_STD`` statistics.
        attack: Attack name. Attacks in ``_FULL_DEFENSE_ATTACKS`` always use
            the full evaluation defense as the attack model.

    Returns:
        ``(model_attack, model_defense, file_suffix)``. ``model_defense`` is
        ``None`` when no defense is configured. ``file_suffix`` is the defense
        name, plus ``'_through'`` for attack-through runs.
    """
    file_suffix = ''

    if defense_arg is not None:
        model_defense = _defended_model(model, defense_arg, defense_param, is_normalized)
        file_suffix += defense_arg
    else:
        model_defense = None

    if not attack_through:
        return model, model_defense, file_suffix

    if attack_param:
        # Explicit parameters for the defense copy used while attacking.
        model_attack = _defended_model(model, defense_arg, attack_param, is_normalized)
    elif attack in _FULL_DEFENSE_ATTACKS:
        model_attack = model_defense
    else:
        iterations = _iterations_of(defense_param)
        if iterations is not None and iterations > 1:
            print('using the non-iterative version of the defense in the attack')
            # Build a modified copy: mutating defense_param in place would leak
            # 'iterations:1' back to the caller (e.g. into the saved config).
            single_step = ['iterations:1' if p.startswith('iterations') else p
                           for p in defense_param]
            model_attack = _defended_model(model, defense_arg, single_step, is_normalized)
        else:
            model_attack = model_defense

    file_suffix += '_through'
    return model_attack, model_defense, file_suffix


def _parse_epsilon(value):
    """Convert an epsilon written as 'X/Y' (e.g. '8/255') or as a plain number to a float."""
    numerator, slash, denominator = str(value).partition('/')
    if slash:
        return float(numerator) / float(denominator)
    return float(numerator)


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='Run an adversarial experiment')
    # arguments to load and save the configuration
    parser.add_argument('--use_config', action='store_true', help='Use the config file')
    parser.add_argument('--config', type=str, help='Path to the config file', default='')
    parser.add_argument('--save_config',action='store_true', help='Save the config file')
    # arguments for the experiment
    parser.add_argument('--model_attack', type=str, help='Model used to create the adversarial example, ResNet50 or Vit', default='ResNet50')
    parser.add_argument('--batch_size', type=int, help='Batch size', default=16)
    parser.add_argument('--dataset', type=str, help='Dataset to use.', default='imagenet_5000')
    # arguments for the attack
    parser.add_argument('--attack', type=str, help='The attack to use. FGSM,iFGSM,PGD ...', default='PGD')
    parser.add_argument('--norm', type=str, help='Norm to use. This parameter is not used atm', default='l_inf')
    parser.add_argument('--epsilons',nargs='*', type=str, help='Epsilons for the attack in the form X/255.', default=['8/255'])
    parser.add_argument('--attack_steps', type=int, help='Number of steps for the attack', default=10)
    # arguments for the defense
    parser.add_argument('--defense', type=str, help='The defense to use. jpeg,HiFiC,ELIC,hyperprior,CRDR', default=None)
    parser.add_argument('--defense_param',nargs='*', type=str, help='Parameters to use for the defense (like quality for JPEG). check the defenses in defense/ for more information', default='')
    parser.add_argument('--attack_model_param',nargs='*', type=str, help='Parameters to use for the attack model (like quality for JPEG). check the defenses in defense/ for more information', default='')
    parser.add_argument('--non_diff',action='store_true', help='---Deprecated---, If set the defense uses the non differentiable version. only available for HiFiC')
    parser.add_argument('--attack_through',action='store_true', help='If true the gradient is run through the defense when creating adversarial examples leading to a white-box attack')
    # arguments for the output
    parser.add_argument('--output', type=str, help='Path to save the results', default='results/')
    parser.add_argument('--get_baseline',action='store_true', help='If True the result includes the baseline accuracy')
    return parser


def main():
    """Parse the arguments, run the configured experiment and save its results."""
    args = _build_parser().parse_args()

    # Load the configuration FIRST, so the validation and the epsilon values
    # below reflect what the config file specifies, not the CLI defaults.
    if args.use_config:
        assign_args_from_config(args, load_config(args.config))

    if args.attack_through and args.defense is None:
        print('Cannot attack through  without a defense')
        args.attack_through = False

    epsilons = [_parse_epsilon(eps) for eps in args.epsilons]

    # Load the model and the test set
    model, preprocessing = load_model(args.model_attack, args.dataset)
    print('preprocessing:', preprocessing)
    testset = load_testset(args.dataset, preprocessing)

    # Create the models for attack and defense
    model_attack, model_defense, file_suffix = create_attack_defense(
        model=model,
        defense_arg=args.defense,
        defense_param=args.defense_param,
        attack_param=args.attack_model_param,
        attack_through=args.attack_through,
        attack=args.attack,
        is_normalized=False,
    )

    # Create and run the experiment
    experiment = Experiment(
        model_attack=model_attack,
        testset=testset,
        model_defense=model_defense,
        dataset_name=args.dataset,
        epsilons=epsilons,
        batch_size=args.batch_size,
    )
    results = experiment.run_experiment(method=args.attack, steps=args.attack_steps)

    # Join the non-empty name parts with '_' so the defense suffix and the
    # timestamp are separated; an empty suffix gives no double underscore.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    file_name = '_'.join(part for part in (args.dataset, args.attack, file_suffix, timestamp) if part)

    # Save the configuration if requested
    if args.save_config:
        if args.config == '':
            args.config = 'configs/' + file_name + '.json'
        save_config(args, args.config)

    # Optionally prepend the baseline (clean) accuracy as an extra column
    if args.get_baseline:
        baseline_accuracy = experiment.get_predictions()
        results = [baseline_accuracy] + results
        args.epsilons = ['0'] + args.epsilons
    save_results(results, args.epsilons, args.output + file_name)


if __name__ == '__main__':
    main()